# ------------------------------------------------------------------------
# Conditional DETR model and criterion classes.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------

import math
import torch
from torch import nn
from .module.modules import MLP

from util.misc import (NestedTensor, nested_tensor_from_tensor_list, inverse_sigmoid)

from .backbones.backbone import build_backbone
from .necks.conditional_transformer import build_transformer
from .matcher.matcher import build_matcher
from .losses.setcriterion import SetCriterion
from .postprocess.postprocess import PostProcess


def build(args):
    device = torch.device(args.device)
    # The `num_classes` naming here is somewhat misleading.
    # It actually corresponds to `max_obj_id + 1`, where max_obj_id
    # is the maximum id for a class in your dataset. For example,
    # COCO has a max_obj_id of 90, so we pass `num_classes` as 91.
    # As another example, for a dataset that has a single class with id 1,
    # you should pass `num_classes` as 2 (max_obj_id + 1).
    # For more details on this, check the following discussion:
    # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
    num_classes = 20 if args.dataset_file != 'coco' else 91
    if args.dataset_file == "coco_panoptic":
        # for panoptic, we just add a num_classes that is large enough to hold
        # max_obj_id + 1, but the exact value doesn't really matter
        num_classes = 250

    backbone = build_backbone(args)

    transformer = build_transformer(args)

    model = ConditionalDETR(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        aux_loss=args.aux_loss,
    )
    #if args.masks:
    #    model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
    weight_dict = get_weight_dict(args)
    losses = get_losses(args)
    matcher = build_matcher(args, 'focal_loss')
    criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,
                             focal_alpha=args.focal_alpha, losses=losses, loss_type='focal_loss')
    criterion.to(device)
    postprocessors = {'bbox': PostProcess(trans_type='ddetr')}
    # if args.masks:
    #     postprocessors['segm'] = PostProcessSegm()
    #     if args.dataset_file == "coco_panoptic":
    #         is_thing_map = {i: i <= 90 for i in range(201)}
    #         postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85)

    return model, criterion, postprocessors
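
# A minimal usage sketch (not from this file: `samples`, `targets`, and
# `orig_target_sizes` are hypothetical placeholders, and `criterion.weight_dict`
# follows the DETR SetCriterion convention; the real `args` namespace comes
# from this repo's argument parser):
#
#   model, criterion, postprocessors = build(args)
#   outputs = model(samples)                     # samples: NestedTensor of images
#   loss_dict = criterion(outputs, targets)
#   loss = sum(loss_dict[k] * criterion.weight_dict[k]
#              for k in loss_dict if k in criterion.weight_dict)
#   results = postprocessors['bbox'](outputs, orig_target_sizes)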


class ConditionalDETR(nn.Module):
    """ This is the Conditional DETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False, num_verb_classes=0):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_classes: number of object classes
            num_queries: number of object queries, i.e. detection slots. This is the maximal number of objects
                         Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            num_verb_classes: number of verb classes; accepted by build_head so subclasses can add an
                              interaction head, but unused by this detection-only head.
        """
        super().__init__()
        self.backbone = backbone
        self.aux_loss = aux_loss
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.hidden_dim = hidden_dim
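        # Learnable object queries (one hidden_dim embedding per detection slot)
        # and a 1x1 projection from the last backbone stage to the transformer width.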
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.input_proj = nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1)
        nn.init.xavier_uniform_(self.input_proj.weight, gain=1)
        nn.init.constant_(self.input_proj.bias, 0)

        self.build_head(num_classes, num_verb_classes)

    def build_head(self, num_classes, num_verb_classes):
        hidden_dim = self.hidden_dim
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        
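        # Focal-loss-style bias init (as in RetinaNet): set the classification
        # bias so every class starts with sigmoid(bias) == prior_prob, i.e.
        # bias = -log((1 - prior_prob) / prior_prob); the box head's last layer
        # starts at zero.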
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)

    def forward(self, samples: NestedTensor, targets=None):
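        """The forward expects a NestedTensor (or a plain list of images, which
        is converted internally) and returns a dict containing at least:
           - "pred_logits": classification logits for all queries, of shape
                            [batch_size, num_queries, num_classes]
           - "pred_boxes": normalized box coordinates (cx, cy, w, h) in [0, 1],
                           of shape [batch_size, num_queries, 4]
           - "aux_outputs": per-decoder-layer predictions, present only when
                            aux_loss is enabled and the model is in training mode
        plus the raw transformer outputs (e.g. 'out_query', 'reference').
        `targets` is accepted for interface compatibility and not used here.
        """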
        srcs, masks, poses = self.backbone_forward(samples)
        outputs = self.neck_forward(srcs, masks, poses, self.query_embed.weight)  # dict with 'out_query', 'reference', ...
        out = self.head_forward(outputs)
        out.update(outputs)
        return out

    def backbone_forward(self, samples: NestedTensor):
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)
        srcs = []
        masks = []
        poses = []
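        # Note: input_proj was built for backbone.num_channels[-1], so the loop
        # below assumes every returned feature level has that channel count
        # (which holds when the backbone exposes only its last stage).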
        for j, feat in enumerate(features):
            src, mask = feat.decompose()
            assert mask is not None
            srcs.append(self.input_proj(src))
            masks.append(mask)
            poses.append(pos[j])

        return srcs, masks, poses

    def neck_forward(self, srcs, masks, poses, query_embed):
        outputs = self.transformer(srcs, masks, query_embed, poses)
        return outputs

    def head_forward(self, outputs):
        x = outputs['out_query']
        reference = outputs['reference']
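        # Reference points live in sigmoid space; map them back to logit space so
        # the box head can predict unbounded offsets. Only the first two channels
        # (cx, cy) are offset by the reference point before the final sigmoid;
        # w and h are predicted directly.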
        reference_before_sigmoid = inverse_sigmoid(reference)
        outputs_coords = []
        for lvl in range(x.shape[0]):
            tmp = self.bbox_embed(x[lvl])
            tmp[..., :2] += reference_before_sigmoid
            outputs_coord = tmp.sigmoid()
            outputs_coords.append(outputs_coord)
        outputs_coord = torch.stack(outputs_coords)

        outputs_class = self.class_embed(x)
        out = {'pred_logits': outputs_class[-1],
               'pred_boxes': outputs_coord[-1]}

        if self.aux_loss and self.training:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]


def get_weight_dict(args):
    weight_dict = {'loss_ce': args.cls_loss_coef, 'loss_bbox': args.bbox_loss_coef}
    weight_dict['loss_giou'] = args.giou_loss_coef
    if args.masks:
        weight_dict["loss_mask"] = args.mask_loss_coef
        weight_dict["loss_dice"] = args.dice_loss_coef
    # TODO this is a hack
    if args.aux_loss:
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
        aux_weight_dict.update({k + '_enc': v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)
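    # With e.g. dec_layers == 6, the dict now holds each base key k in
    # {'loss_ce', 'loss_bbox', 'loss_giou'} plus per-layer copies k_0 .. k_4 and
    # a k_enc copy (the '_enc' entries only matter when an encoder/two-stage
    # auxiliary loss is actually computed).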
    return weight_dict


def get_losses(args):
    losses = ['labels', 'boxes', 'cardinality']
    if args.masks:
        losses += ["masks"]
    return losses